package com.lw.leetcode.tree.b;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Created with IntelliJ IDEA.
 * 1786. Number of Restricted Paths From First to Last Node (从第一个节点出发到最后一个节点的受限路径数)
 *
 * @author liw
 * @version 1.0
 * @date 2022/8/18 9:44
 */
public class CountRestrictedPaths {

    /** Modulus applied to the path count. */
    private static final int MOD = 1_000_000_007;

    public static void main(String[] args) {
        CountRestrictedPaths test = new CountRestrictedPaths();

        // Expected: 3
        // Shortest distance to node 5 for nodes 1..5: [4, 2, 1, 6, 0]
//        int n = 5;
//        int[][] edges = {{1, 2, 3}, {1, 3, 3}, {2, 3, 1}, {1, 4, 2}, {5, 2, 2}, {3, 5, 1}, {5, 4, 10}};

        // Expected: 1
        int n = 7;
        int[][] edges = {{1, 3, 1}, {4, 1, 2}, {7, 3, 4}, {2, 5, 3}, {5, 6, 1}, {6, 7, 2}, {7, 5, 3}, {2, 6, 4}};

        int i = test.countRestrictedPaths(n, edges);
        System.out.println(i);
    }

    /**
     * Counts paths from node 1 to node {@code n} along which the shortest distance to node
     * {@code n} strictly decreases at every step, modulo {@code 1_000_000_007}.
     *
     * <p>Distances are kept as {@code long}, so path lengths beyond {@code Integer.MAX_VALUE}
     * are handled correctly.
     *
     * @param n     number of nodes; nodes are numbered {@code 1..n}
     * @param edges undirected edges {@code {u, v, weight}}; weights are assumed non-negative
     *              (required for Dijkstra's algorithm)
     * @return the number of restricted paths modulo {@code 1_000_000_007}; 0 if node 1 cannot
     *         reach node {@code n}
     */
    public int countRestrictedPaths(int n, int[][] edges) {
        List<List<int[]>> graph = buildGraph(n, edges);
        long[] dist = shortestDistancesFrom(n, graph);
        return countDecreasingPaths(n, graph, dist);
    }

    /** Builds an undirected adjacency list; each entry is {@code {neighbor, weight}}. */
    private static List<List<int[]>> buildGraph(int n, int[][] edges) {
        List<List<int[]>> graph = new ArrayList<>(n + 1);
        for (int i = 0; i <= n; i++) {
            graph.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            int u = edge[0];
            int v = edge[1];
            int w = edge[2];
            graph.get(u).add(new int[] {v, w});
            graph.get(v).add(new int[] {u, w});
        }
        return graph;
    }

    /**
     * Runs Dijkstra's algorithm from {@code source}.
     *
     * @return shortest distance from {@code source} to every node, or {@code Long.MAX_VALUE}
     *         for nodes that are unreachable
     */
    private static long[] shortestDistancesFrom(int source, List<List<int[]>> graph) {
        long[] dist = new long[graph.size()];
        Arrays.fill(dist, Long.MAX_VALUE);
        dist[source] = 0;
        // Entries are {distance, node}, ordered by distance.
        PriorityQueue<long[]> queue = new PriorityQueue<>((a, b) -> Long.compare(a[0], b[0]));
        queue.add(new long[] {0, source});
        while (!queue.isEmpty()) {
            long[] top = queue.poll();
            long d = top[0];
            int u = (int) top[1];
            if (d > dist[u]) {
                // Stale entry: a shorter distance to u was already settled.
                continue;
            }
            for (int[] edge : graph.get(u)) {
                int v = edge[0];
                long candidate = d + edge[1];
                if (candidate < dist[v]) {
                    dist[v] = candidate;
                    queue.add(new long[] {candidate, v});
                }
            }
        }
        return dist;
    }

    /**
     * Counts paths from node 1 to node {@code n} whose distances strictly decrease.
     *
     * <p>Nodes are visited in ascending distance order. When node u is reached, every neighbor
     * with a strictly smaller distance already holds its final count. Nodes with equal distance
     * are never adjacent in this DAG.
     */
    private static int countDecreasingPaths(int n, List<List<int[]>> graph, long[] dist) {
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i + 1;
        }
        Arrays.sort(order, (a, b) -> Long.compare(dist[a], dist[b]));

        // ways[u] = number of restricted paths from u to n.
        long[] ways = new long[n + 1];
        ways[n] = 1;
        for (int u : order) {
            if (u == n) {
                continue;
            }
            long total = 0;
            for (int[] edge : graph.get(u)) {
                int v = edge[0];
                if (dist[v] < dist[u]) {
                    total = (total + ways[v]) % MOD;
                }
            }
            ways[u] = total;
        }
        return (int) ways[1];
    }

}
